
package nsalt.util

import chisel3._
import chisel3.util._


object WordShift {
  /** Shifts `data` left by `wordIndex * step` bit positions.
    *
    * Typical use is placing a `step`-bit word at word slot `wordIndex` of a wider bus.
    * The shift amount is a hardware value, so this is a dynamic left shift.
    *
    * @param data      value to shift
    * @param wordIndex word slot index (hardware value)
    * @param step      width of one word slot, in bits
    */
  def apply(data: UInt, wordIndex: UInt, step: Int) = {
    val shiftAmount = wordIndex * step.U
    data << shiftAmount
  }
}

object MaskExpand {
  /** Expands a per-byte mask into a per-bit mask.
    *
    * Bit `i` of `m` is replicated 8 times to fill bits `[8*i+7, 8*i]` of the result,
    * so the result is `8 * m.getWidth` bits wide.
    */
  def apply(m: UInt) = {
    // Reverse first so the LSB byte ends up as the last (least significant) Cat operand.
    val bytesMsbFirst = m.asBools.reverse.map(bit => Fill(8, bit))
    Cat(bytesMsbFirst)
  }
}

object MaskData {
  /** Bitwise merge: bits set in `fullmask` are taken from `newData`, all others from `oldData`. */
  def apply(oldData: UInt, newData: UInt, fullmask: UInt) = {
    val inserted = newData & fullmask
    val kept     = oldData & ~fullmask
    inserted | kept
  }
}

object SignExt {
  /** Sign-extends `a` to `len` bits by replicating its most significant bit.
    *
    * If `a` is already at least `len` bits wide, it is truncated to its low `len` bits instead.
    */
  def apply(a: UInt, len: Int) = {
    val width = a.getWidth
    val pad   = len - width
    if (pad <= 0) a(len - 1, 0)
    else Cat(Fill(pad, a(width - 1)), a)
  }
}

object ZeroExt {
  /** Zero-extends `a` to `len` bits by prepending zero bits.
    *
    * If `a` is already at least `len` bits wide, it is truncated to its low `len` bits instead.
    */
  def apply(a: UInt, len: Int) = {
    val pad = len - a.getWidth
    if (pad <= 0) a(len - 1, 0)
    else Cat(0.U(pad.W), a)
  }
}

object LookupTree {
  /** Selects the value whose key equals `key`, using a one-hot mux (`Mux1H`).
    *
    * Each entry's select is `entryKey === key`. The result is only well defined when
    * exactly one entry matches. Per the `Mux1H` contract, zero or multiple matches give an
    * unspecified value, so callers must guarantee distinct keys and full coverage, or
    * use [[LookupTreeDefault]].
    */
  def apply[T <: Data](key: UInt, mapping: Iterable[(UInt, T)]): T = {
    val selects = mapping.map { case (entryKey, value) => (entryKey === key) -> value }
    Mux1H(selects)
  }
}

object LookupTreeDefault {
  /** Returns the value whose key equals `key`, or `default` when no key matches (via `MuxLookup`). */
  def apply[T <: Data](key: UInt, default: T, mapping: Iterable[(UInt, T)]): T = {
    val cases: Seq[(UInt, T)] = mapping.toSeq
    MuxLookup(key, default, cases)
  }
}

object PipeLink {
  /** Single register stage between two `DecoupledIO` endpoints.
    *
    * - `right.bits` is a register loaded with `left.bits` whenever `left.valid && right.ready`.
    * - `right.valid` is driven from a valid register. Its update priority, highest first:
    *   `isFlush` clears it, `left.valid && right.ready` sets it, `rightOutFire` clears it,
    *   otherwise it holds.
    * - `left.ready` is driven combinationally from `right.ready`.
    *
    * `isFlush` only clears the valid register for the next cycle. `right.valid` is not
    * masked by `isFlush` in the cycle the flush is asserted.
    */
  def apply[T <: Data](left: DecoupledIO[T], right: DecoupledIO[T], rightOutFire: Bool, isFlush: Bool): Unit = {
    val accept     = left.valid && right.ready
    val stageValid = RegInit(false.B)

    when (isFlush) {
      stageValid := false.B
    } .elsewhen (accept) {
      stageValid := true.B
    } .elsewhen (rightOutFire) {
      stageValid := false.B
    }

    left.ready  := right.ready
    right.bits  := RegEnable(left.bits, accept)
    right.valid := stageValid
  }
}

object RegMap {

  /** Write-function sentinel: entries whose `wfn` is this value get no write logic from [[generate]]. */
  def Unwritable = null

  /** Builds one register-map entry.
    *
    * @param addr address at which the register is mapped
    * @param reg  the register (read source and write target)
    * @param wfn  transforms the masked write value before it is stored; identity by default
    */
  def apply(addr: Int, reg: UInt, wfn: UInt => UInt = (x => x)) =
    (addr, (reg, wfn))

  /** Elaborates read and write logic for a register map with separate read/write addresses.
    *
    * Read: `rdata` is the register whose address equals `raddr`, selected with [[LookupTree]]
    * (one-hot, so `raddr` should match exactly one entry).
    *
    * Write: for each entry with a non-null write function, when `wen` and `waddr` matches, the
    * register is assigned `wfn(MaskData(reg, wdata, wmask))`. Bits set in `wmask` come from
    * `wdata`; the rest keep the register's current value.
    */
  def generate(
    mapping: Map[Int, (UInt, UInt => UInt)],
    raddr: UInt, rdata: UInt,
    waddr: UInt, wdata: UInt, wen: Bool, wmask: UInt
  ):Unit = {

    // Map.map with a non-pair result yields an Iterable, so no entries are merged here.
    val chiselMapping = mapping
      .map { case (a, (r, w)) => (a.U, r, w) }

    rdata := LookupTree(raddr, chiselMapping.map { case (a, r, _) => (a, r) })

    // Elaborated purely for its side effect of emitting `when` blocks, so use foreach.
    chiselMapping.foreach { case (a, r, w) =>
      if (w != null) when (wen && waddr === a) {
        r := w(MaskData(r, wdata, wmask))
      }
    }
  }

  /** Same as the full [[generate]], with a single shared address for reads and writes. */
  def generate(
    mapping: Map[Int, (UInt, UInt => UInt)],
    addr: UInt, rdata: UInt,
    wdata: UInt, wmask: UInt, wen: Bool
  ):Unit =
    generate(mapping, addr, rdata, addr, wdata, wen, wmask)

}

object MaskedRegMap {
  // Bit width of the default masks below.
  // NOTE(review): fixed at 32; presumably this stands in for a configurable data width — confirm.
  private val MaskWidth = 32

  /** Write-function sentinel: entries whose `wfn` is this value get no write logic from [[generate]]. */
  def Unwritable = null
  /** Identity write function: the masked write value is stored unchanged. */
  def NoSideEffect: UInt => UInt = (x=>x)
  /** All-ones mask, `MaskWidth` bits wide. */
  def WritableMask = Fill(MaskWidth, true.B)
  /** All-zero literal mask, `MaskWidth` bits wide: no bit is writable. */
  def UnwritableMask = 0.U(MaskWidth.W)

  /** Builds one masked register-map entry.
    *
    * @param addr  address at which the register is mapped
    * @param reg   the register (read source and write target)
    * @param wmask bits set here are taken from the write data; the rest keep their value
    * @param wfn   transforms the masked write value before it is stored
    * @param rmask ANDed with the register on reads
    */
  def apply(addr: Int, reg: UInt, wmask: UInt = WritableMask, wfn: UInt => UInt = (x => x), rmask: UInt = WritableMask) = (addr, (reg, wmask, wfn, rmask))

  /** True when `wm` is a hardware literal equal to zero, i.e. no bit of the register is writable.
    *
    * A value check is required: `!=` on Chisel `Data` is Scala object inequality, and
    * `UnwritableMask` builds a fresh literal on every call, so an identity comparison never matches.
    */
  private def hasNoWritableBits(wm: UInt): Boolean = wm.isLit && wm.litValue == BigInt(0)

  /** Elaborates read and write logic for a masked register map.
    *
    * Read: `rdata` is `reg & rmask` of the entry whose address equals `raddr`, selected with
    * [[LookupTree]] (one-hot, so `raddr` should match exactly one entry).
    *
    * Write: when `wen` and `waddr` matches, the register is assigned
    * `wfn(MaskData(reg, wdata, wmask))`. No write logic is emitted for entries whose write
    * function is null ([[Unwritable]]) or whose write mask is a literal zero (e.g. [[UnwritableMask]]).
    */
  def generate(mapping: Map[Int, (UInt, UInt, UInt => UInt, UInt)], raddr: UInt, rdata: UInt,
    waddr: UInt, wen: Bool, wdata: UInt):Unit = {
    val chiselMapping = mapping.map { case (a, (r, wm, w, rm)) => (a.U, r, wm, w, rm) }
    rdata := LookupTree(raddr, chiselMapping.map { case (a, r, _, _, rm) => (a, r & rm) })
    // Elaborated purely for its side effect of emitting `when` blocks, so use foreach.
    chiselMapping.foreach { case (a, r, wm, w, _) =>
      if (w != null && !hasNoWritableBits(wm)) {
        when (wen && waddr === a) { r := w(MaskData(r, wdata, wm)) }
      }
    }
  }

  /** Returns a Bool that is true when `addr` matches none of the addresses in `mapping`. */
  def isIllegalAddr(mapping: Map[Int, (UInt, UInt, UInt => UInt, UInt)], addr: UInt):Bool = {
    val illegalAddr = Wire(Bool())
    illegalAddr := LookupTreeDefault(addr, true.B, mapping.keys.toSeq.map(a => (a.U, false.B)))
    illegalAddr
  }

  /** Same as the full [[generate]], with a single shared address for reads and writes. */
  def generate(mapping: Map[Int, (UInt, UInt, UInt => UInt, UInt)], addr: UInt, rdata: UInt,
    wen: Bool, wdata: UInt):Unit = generate(mapping, addr, rdata, addr, wen, wdata)
}


object LFSR64 {
  /** 64-bit Fibonacci-style LFSR.
    *
    * Each cycle that `increment` is high, the register shifts right by one. The new MSB is the
    * XOR of bits 0, 1, 3 and 4. NOTE(review): presumably this is the maximal-length polynomial
    * x^64 + x^63 + x^61 + x^60 + 1 — confirm.
    *
    * If the state is ever all zeros, the next increment loads 1 instead, so the register cannot
    * stay locked at zero.
    *
    * @param increment advance the sequence when high (always advances by default)
    * @param seed      reset value of the register; a fixed constant by default, not derived from
    *                  any simulation seed. Must be non-negative and fit in 64 bits.
    * @return the current 64-bit register value
    */
  def apply(increment: Bool = true.B, seed: BigInt = BigInt("1234567887654321", 16)): UInt = {
    val wide = 64
    require(seed >= 0 && seed.bitLength <= wide, s"LFSR64 seed must be a non-negative $wide-bit value")
    val lfsr = RegInit(seed.U(wide.W))
    val xor = lfsr(0) ^ lfsr(1) ^ lfsr(3) ^ lfsr(4)
    when (increment) {
      lfsr := Mux(lfsr === 0.U, 1.U, Cat(xor, lfsr(wide-1,1)))
    }
    lfsr
  }
}